import sys
import time

import GPUtil
import numpy as np
from OpenGL.GL import (
    GL_BUFFER_UPDATE_BARRIER_BIT,
    GL_COLOR_BUFFER_BIT,
    GL_COMPUTE_SHADER,
    GL_DYNAMIC_DRAW,
    GL_FRAGMENT_SHADER,
    GL_MODELVIEW,
    GL_POINTS,
    GL_PROJECTION,
    GL_SHADER_STORAGE_BARRIER_BIT,
    GL_SHADER_STORAGE_BUFFER,
    GL_VERTEX_SHADER,
    glBegin,
    glBindBuffer,
    glBindBufferBase,
    glBindVertexArray,
    glBufferData,
    glClear,
    glColor3f,
    glDispatchCompute,
    glEnd,
    glGenBuffers,
    glGenVertexArrays,
    glGetBufferSubData,
    glGetUniformLocation,
    glLoadIdentity,
    glMatrixMode,
    glMemoryBarrier,
    glPopMatrix,
    glPushMatrix,
    glRasterPos2f,
    glUniform1f,
    glUseProgram,
    glVertex2f,
)
from OpenGL.GL.shaders import compileProgram, compileShader
from OpenGL.GLU import gluOrtho2D, gluPerspective
from OpenGL.GLUT import *

# -------------------------------
# GPU Compute Shader (Prime Check)
# -------------------------------
# GLSL 4.30 compute shader source. Invocation `gid` tests the integer
# `gid + 2` for primality by trial division over odd candidates and writes
# 1 (prime) or 0 (not prime) to primes[gid] in the SSBO at binding 0.
compute_shader_src = """
#version 430
layout(local_size_x = 256) in;

layout(std430, binding=0) buffer Results {
    uint primes[];
};

uint is_prime(uint n) {
    if (n < 2u) return 0u;
    if (n == 2u) return 1u;
    if (n % 2u == 0u) return 0u;
    // Integer loop bound: a floating-point square root that rounds down
    // would skip the last divisor of a perfect square, and i * i could
    // overflow for large n, so compare i against n / i instead.
    for (uint i = 3u; i <= n / i; i += 2u) {
        if (n % i == 0u) return 0u;
    }
    return 1u;
}

void main() {
    uint gid = gl_GlobalInvocationID.x;
    // Dispatches cover whole work groups, so trailing invocations may
    // index past the end of the bound buffer; skip them.
    if (gid >= uint(primes.length())) return;
    primes[gid] = is_prime(gid + 2u);
}
"""

# -------------------------------
# Lattice Shader (GLSL Vertex/Fragment)
# -------------------------------
# GLSL 3.30 vertex shader source: passes the 2D `pos` attribute (location 0)
# straight through as the clip-space position with z = 0.0 and w = 1.0.
vertex_shader_src = """
#version 330
layout(location=0) in vec2 pos;
void main() {
    gl_Position = vec4(pos, 0.0, 1.0);
}
"""

# GLSL 3.30 fragment shader source: gives every fragment the opaque colour
# (glow, 0.8 * glow, 0.4 * glow), where `glow` is a float uniform set by the
# application before drawing.
fragment_shader_src = """
#version 330
out vec4 fragColor;
uniform float glow;
void main() {
    fragColor = vec4(glow, glow*0.8, glow*0.4, 1.0);
}
"""

# -------------------------------
# Globals
# -------------------------------
# GLUT window handle; assigned in main()
window = None

# Lattice layout and workload size
lattice_width = 128                                  # points per lattice row
num_instances = 10_000                               # integers tested / points drawn per frame
num_instances_max = 5_000_000                        # upper clamp for the auto-tuner
num_instances_min = lattice_width * lattice_width    # lower clamp for the auto-tuner
glow_phase = 0.0                                     # phase of the sine-driven glow

# Most recent GPU telemetry, shown on the HUD
gpu_util = 0.0
gpu_mem = "0/0 MB"
gpu_clock = 0.0
primes_found = 0

# Auto-tune settings
target_util = 0.85      # desired GPU load (85%)
target_margin = 0.05    # tolerated deviation around the target (±5%)
scale_step = 5_000      # instances added or removed per frame

# OpenGL object handles; populated by init_gl()
compute_program = None
render_program = None
ssbo = None
vao = None

# -------------------------------
# OpenGL Initialization
# -------------------------------
def init_gl():
    """Compile the shader programs and create the GL objects used each frame.

    Populates the module globals ``compute_program``, ``render_program``,
    ``ssbo`` and ``vao``. ``main()`` calls this once, after the window has
    been created.
    """
    global compute_program, render_program, ssbo, vao

    # Compile and link the compute and render programs
    compute_program = compileProgram(
        compileShader(compute_shader_src, GL_COMPUTE_SHADER)
    )
    render_program = compileProgram(
        compileShader(vertex_shader_src, GL_VERTEX_SHADER),
        compileShader(fragment_shader_src, GL_FRAGMENT_SHADER),
    )

    # SSBO for prime results (one uint32 per tested integer).
    # compute_primes() rounds its dispatch up to whole 256-wide work groups,
    # so the capacity is rounded up the same way; otherwise the last group
    # would write past the end of the buffer when num_instances reaches
    # num_instances_max.
    work_group_size = 256
    capacity = -(-num_instances_max // work_group_size) * work_group_size
    ssbo = glGenBuffers(1)
    glBindBuffer(GL_SHADER_STORAGE_BUFFER, ssbo)
    glBufferData(GL_SHADER_STORAGE_BUFFER, capacity * 4, None, GL_DYNAMIC_DRAW)
    glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 0, ssbo)
    glBindBuffer(GL_SHADER_STORAGE_BUFFER, 0)

    # VAO for lattice (left bound on return)
    vao = glGenVertexArrays(1)
    glBindVertexArray(vao)

# -------------------------------
# Update & Compute Primes
# -------------------------------
def compute_primes():
    """Run the prime-check compute shader and count the primes it reports.

    Dispatches enough 256-wide work groups to cover ``num_instances``
    integers, reads the first ``num_instances`` uint32 results back from the
    SSBO, and stores their sum in the global ``primes_found``.
    """
    global primes_found

    glUseProgram(compute_program)
    glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 0, ssbo)

    # Dispatch compute shader, rounding up to whole work groups
    group_count = (num_instances + 255) // 256
    glDispatchCompute(group_count, 1, 1)

    # The results are read back with glGetBufferSubData, which is ordered
    # against shader writes by GL_BUFFER_UPDATE_BARRIER_BIT; the
    # shader-storage bit only covers later shader accesses and is kept for
    # the next frame's dispatch.
    glMemoryBarrier(GL_SHADER_STORAGE_BARRIER_BIT | GL_BUFFER_UPDATE_BARRIER_BIT)

    # Read back results
    glBindBuffer(GL_SHADER_STORAGE_BUFFER, ssbo)
    raw = glGetBufferSubData(GL_SHADER_STORAGE_BUFFER, 0, num_instances * 4)
    glBindBuffer(GL_SHADER_STORAGE_BUFFER, 0)

    flags = np.frombuffer(raw, dtype=np.uint32)
    primes_found = int(np.sum(flags))

# -------------------------------
# Draw Lattice
# -------------------------------
def draw_lattice():
    """GLUT display callback: compute, draw one frame, and auto-tune the load.

    Per frame this runs the prime computation, draws ``num_instances`` points
    with a glow scaled by the fraction of primes found, refreshes the GPU
    telemetry globals, adjusts ``num_instances`` toward ``target_util``,
    draws the HUD and swaps buffers.
    """
    global glow_phase, num_instances, gpu_util, gpu_mem, gpu_clock

    glClear(GL_COLOR_BUFFER_BIT)
    glLoadIdentity()

    compute_primes()

    # Glow intensity: sine pulse in [0, 1] scaled by the prime fraction
    glow_phase += 0.05
    glow = (np.sin(glow_phase) * 0.5 + 0.5) * min(1.0, primes_found / num_instances)

    glUseProgram(render_program)
    glUniform1f(glGetUniformLocation(render_program, "glow"), glow)

    # One point per instance, laid out row by row, lattice_width per row
    glBindVertexArray(vao)
    glBegin(GL_POINTS)
    for i in range(num_instances):
        row, col = divmod(i, lattice_width)
        glVertex2f(col / lattice_width * 2 - 1, row / lattice_width * 2 - 1)
    glEnd()

    # GPU telemetry: query once per frame rather than once for the
    # emptiness check and again for the element.
    gpus = GPUtil.getGPUs()
    gpu = gpus[0] if gpus else None
    if gpu is not None:
        gpu_util = gpu.load
        gpu_mem = f"{gpu.memoryUsed}/{gpu.memoryTotal} MB"
        # `clock` is not among the attributes GPUtil documents for its GPU
        # objects; keep the previous value instead of raising AttributeError
        # when it is absent.
        gpu_clock = getattr(gpu, "clock", gpu_clock)

        # Auto-tune num_instances to reach target_util
        if gpu_util < target_util - target_margin:
            num_instances = min(num_instances + scale_step, num_instances_max)
        elif gpu_util > target_util + target_margin:
            num_instances = max(num_instances - scale_step, num_instances_min)
    else:
        # No telemetry available: keep ramping up until the cap
        num_instances = min(num_instances + scale_step, num_instances_max)

    draw_hud()
    glutSwapBuffers()

# -------------------------------
# HUD Overlay
# -------------------------------
def draw_hud():
    """Draw the telemetry text overlay near the top-left of the window.

    Switches to an 800x800 orthographic projection, writes one line per
    statistic with GLUT bitmap characters, then restores the previous
    projection and model-view matrices and leaves GL_MODELVIEW selected.
    The current shader program is unbound as a side effect.
    """
    # The text below relies on the fixed-function pipeline. If a shader
    # program were still bound, the raster position and bitmap fragments
    # would go through its shaders and ignore the matrices and colour set
    # here.
    glUseProgram(0)

    glMatrixMode(GL_PROJECTION)
    glPushMatrix()
    glLoadIdentity()
    gluOrtho2D(0, 800, 0, 800)
    glMatrixMode(GL_MODELVIEW)
    glPushMatrix()
    glLoadIdentity()

    glColor3f(0.2, 1.0, 0.2)
    lines = [
        f"[HDGL] Instances: {num_instances}",
        f"GPU Load: {gpu_util*100:.1f}%",
        f"Mem: {gpu_mem}",
        f"Clock: {gpu_clock:.2f} MHz",
        f"[GRA] Primes found: {primes_found}"
    ]
    # Start near the top edge and step down 18 units per line
    y = 770
    for line in lines:
        glRasterPos2f(10, y)
        for ch in line:
            glutBitmapCharacter(GLUT_BITMAP_9_BY_15, ord(ch))
        y -= 18

    # Undo the matrix pushes in reverse order
    glMatrixMode(GL_MODELVIEW)
    glPopMatrix()
    glMatrixMode(GL_PROJECTION)
    glPopMatrix()
    glMatrixMode(GL_MODELVIEW)

def update(value):
    """GLUT timer callback: request a redraw and re-arm the timer.

    ``value`` is the argument GLUT passes to timer callbacks; it is unused.
    """
    frame_interval_ms = 33
    glutPostRedisplay()
    glutTimerFunc(frame_interval_ms, update, 0)

# -------------------------------
# Main
# -------------------------------
def main():
    """Create the GLUT window, set up GL objects and run the event loop.

    Stores the window handle in the global ``window``, registers
    ``draw_lattice`` as the display callback and starts the ``update`` timer.
    """
    global window

    side = 800
    title = b"HDGL + GPU Prime Solver Auto-Tune"

    glutInit(sys.argv)
    glutInitDisplayMode(GLUT_RGBA | GLUT_DOUBLE)
    glutInitWindowSize(side, side)
    glutInitWindowPosition(0, 0)
    window = glutCreateWindow(title)

    init_gl()

    glutDisplayFunc(draw_lattice)
    glutTimerFunc(0, update, 0)
    glutMainLoop()


if __name__ == "__main__":
    main()
